Skip to content

BUG: fix torch.result_type cross-kind promotion #55

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 6, 2023

Conversation

lucascolley
Copy link
Member

Reference comment: scipy/scipy#19051 (comment) @rgommers

Expected behaviour: array_api_compat.torch.result_type carries out cross-kind promotion like torch.result_type.

Observed behaviour:

In [1]: import torch

In [2]: t = torch.tensor([[0, 2], [1, 1], [2, 0]]).T

In [3]: from array_api_compat import array_namespace

In [4]: xp = array_namespace(t)

In [5]: xp.result_type(t, xp.float64)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[5], line 1
----> 1 xp.result_type(t, xp.float64)

File ~/dev/array-api-compat/array_api_compat/torch/_aliases.py:136, in result_type(*arrays_and_dtypes)
    131     return _promotion_table[xdt, ydt]
    133 # This doesn't result_type(dtype, dtype) for non-array API dtypes
    134 # because torch.result_type only accepts tensors. This does however, allow
    135 # cross-kind promotion.
--> 136 return torch.result_type(x, y)

TypeError: result_type() received an invalid combination of arguments - got (Tensor, torch.dtype), but expected one of:
 * (Tensor tensor, Tensor other)
      didn't match because some of the arguments have invalid types: (Tensor, torch.dtype)
 * (Number scalar, Tensor tensor)
      didn't match because some of the arguments have invalid types: (Tensor, torch.dtype)
 * (Tensor tensor, Number other)
      didn't match because some of the arguments have invalid types: (Tensor, torch.dtype)
 * (Number scalar1, Number scalar2)
      didn't match because some of the arguments have invalid types: (Tensor, torch.dtype)

Improved behaviour on this branch:

In [1]: import torch

In [2]: t = torch.tensor([[0, 2], [1, 1], [2, 0]]).T

In [3]: from array_api_compat import array_namespace

In [4]: xp = array_namespace(t)

In [5]: xp.result_type(t, xp.float64)
Out[5]: torch.float64

Copy link
Member

@rgommers rgommers left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That looks right to me, and matches with the comment above. I'll wait for @asmeurer to have a look as well.

@rgommers
Copy link
Member

rgommers commented Sep 6, 2023

Okay, this is almost certainly correct and it has been open for a week. So I'll hit the green button here. Thanks @lucascolley!

@rgommers rgommers merged commit f047068 into data-apis:main Sep 6, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants